{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.5 TensorFlow最佳实践样例程序\n",
    "在5.2.1节中已经给出了一个完整的TensorFlow程序来解决MNIST问题。然而如5.3节和5.4节所述，这个程序的可扩展性并不好，而且也没有持久化训练好的模型。结合前两节介绍的变量管理机制和持久化机制，本节给出了一个TensorFlow训练神经网络模型的最佳实践。\n",
    "\n",
    "在本样例程序中，**将训练和测试分成两个独立的程序，这可以使得每个组件更加灵活**，比如训练神经网络的程序可以持续输出训练好的模型，而测试程序可以每隔一段时间检验最新模型的正确率，如果模型效果更好，则将这个模型提供给产品使用。另外因为在训练和测试过程中都会用到前向传播，本样例中还**将前向传播的过程抽象成一个单独的库函数，这样使用更加方便，也可以保证两个过程中使用到的前向传播是一致的。**\n",
    "\n",
    "本节将提供重构之后的程序来解决MNIST问题，重构之后的代码将本拆分为3个程序：\n",
    "- 第一个是mnist_inferrnce.py，它定义了前向传播的过程和神经网络中的参数；\n",
    "- 第二个是mnist_train.py，它定义了神经网络的训练过程；\n",
    "- 第三个是mnist_eval.py，它定义了测试过程。\n",
    "\n",
    "前两个文件见本文件同目录，第三个文件直接给出在本文件中，如下：\n",
    "\n",
    "*（需要说明的是，由于本人电脑是GTX1060 6G版，同时运行train和eval会报错：`InternalError: Blas SGEMM launch failed`。参考[here](https://blog.csdn.net/Vinsuan1993/article/details/81142855)，于是需要使用`gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.4) `在初始化Session的时候为其分配固定数量的显存，这里分别分配0.5/0.4）*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-1-6950bd60d591>:54: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please write your own downloading logic.\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting ../../../datasets/MNIST_data\\train-images-idx3-ubyte.gz\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting ../../../datasets/MNIST_data\\train-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.one_hot on tensors.\n",
      "Extracting ../../../datasets/MNIST_data\\t10k-images-idx3-ubyte.gz\n",
      "Extracting ../../../datasets/MNIST_data\\t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From d:\\python3\\tfgpu\\dl+\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "INFO:tensorflow:Restoring parameters from MNIST_model/mnist_model-13001\n",
      "After 13001 training step(s), validation accuracy = 0.984\n",
      "INFO:tensorflow:Restoring parameters from MNIST_model/mnist_model-18001\n",
      "After 18001 training step(s), validation accuracy = 0.984\n",
      "INFO:tensorflow:Restoring parameters from MNIST_model/mnist_model-22001\n",
      "After 22001 training step(s), validation accuracy = 0.986\n",
      "INFO:tensorflow:Restoring parameters from MNIST_model/mnist_model-27001\n",
      "After 27001 training step(s), validation accuracy = 0.9856\n",
      "INFO:tensorflow:Restoring parameters from MNIST_model/mnist_model-29001\n",
      "After 29001 training step(s), validation accuracy = 0.9856\n",
      "INFO:tensorflow:Restoring parameters from MNIST_model/mnist_model-29001\n",
      "After 29001 training step(s), validation accuracy = 0.9856\n",
      "INFO:tensorflow:Restoring parameters from MNIST_model/mnist_model-29001\n",
      "After 29001 training step(s), validation accuracy = 0.9856\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-1-6950bd60d591>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m     56\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     57\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'__main__'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 58\u001b[1;33m     \u001b[0mmain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[1;32m<ipython-input-1-6950bd60d591>\u001b[0m in \u001b[0;36mmain\u001b[1;34m(argv)\u001b[0m\n\u001b[0;32m     53\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mmain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margv\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     54\u001b[0m     \u001b[0mmnist\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0minput_data\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mread_data_sets\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"../../../datasets/MNIST_data\"\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mone_hot\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 55\u001b[1;33m     \u001b[0mevaluate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmnist\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     56\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     57\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0m__name__\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;34m'__main__'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m<ipython-input-1-6950bd60d591>\u001b[0m in \u001b[0;36mevaluate\u001b[1;34m(mnist)\u001b[0m\n\u001b[0;32m     48\u001b[0m                     \u001b[1;32mreturn\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     49\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 50\u001b[1;33m             \u001b[0mtime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msleep\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mEVAL_INTERVAL_SECS\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     51\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     52\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import time\n",
    "import tensorflow as tf\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "\n",
    "# 加载mnist_inference.py和mnist_train.py中定义的常量和函数\n",
    "import mnist_inference\n",
    "import mnist_train\n",
    "\n",
    "\n",
    "# 每10秒加载一次最新模型，并在测试数据上测试最新模型的正确率\n",
    "EVAL_INTERVAL_SECS = 10\n",
    "\n",
    "def evaluate(mnist):\n",
    "    with tf.Graph().as_default() as g:\n",
    "        # 1. 定义神经网络输入输出\n",
    "        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')\n",
    "        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')\n",
    "        validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}\n",
    "\n",
    "        # 2. 计算前向传播结果\n",
    "        y = mnist_inference.inference(x, None)\n",
    "        \n",
    "        # 3. 计算前向传播的结果的正确率，如果需要对未知的样例进行分了，那么使用tf.argmax(y, 1)就可以得到预测类别\n",
    "        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))\n",
    "        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "\n",
    "        # 通过变量重命名的方式来加载模型，这样在前向传播的过程中就不需要调用求滑动平均的\n",
    "        # 函数来获取平均值了。这样就可以完全共用mnist_inference.py 中定义的前向传播过程。\n",
    "        variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)\n",
    "        variables_to_restore = variable_averages.variables_to_restore()\n",
    "        saver = tf.train.Saver(variables_to_restore)\n",
    "        \n",
    "        # 每隔EVAL_INTERVAL_SECS秒调用一次计算正确率的过程，并检测训练过程中的正确率变化\n",
    "        while True:\n",
    "            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.4)    # 给当前session分配固定数量显存\n",
    "            with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:\n",
    "                # tf.train.get_checkpoint_state函数会通过checkpoint文件自动找到目录中最新的文件名\n",
    "                ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)\n",
    "                if ckpt and ckpt.model_checkpoint_path:\n",
    "                    # 加载模型\n",
    "                    saver.restore(sess, ckpt.model_checkpoint_path)\n",
    "                    # 通过文件名得到模型保存时迭代的轮数\n",
    "                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]\n",
    "                    accuracy_score = sess.run(accuracy, feed_dict=validate_feed)\n",
    "                    print(\"After %s training step(s), validation accuracy = %g\" % (global_step, accuracy_score))\n",
    "                else:\n",
    "                    print('No checkpoint file found')\n",
    "                    return\n",
    "                \n",
    "            time.sleep(EVAL_INTERVAL_SECS)\n",
    "            \n",
    "\n",
    "def main(argv=None):\n",
    "    mnist = input_data.read_data_sets(\"../../../datasets/MNIST_data\", one_hot=True)\n",
    "    evaluate(mnist)\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
